# This program is free software, you can redistribute it and/or modify it.
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This file is a part of the CANN Open Software.
# Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

import os
import shutil
import subprocess
from setuptools import setup, find_packages, Distribution, Command
from wheel.bdist_wheel import bdist_wheel

# package_name and version
PACKAGE_NAME = "npu_math_extension"
VERSION = "1.0.0"


class CleanCommand(Command):
    """
    usage: python3 setup.py clean
    """
    user_options = []
    
    def initialize_options(self):
        pass

    def finalize_options(self):
        pass

    def run(self):
        dirs_to_remove = ['build', 'dist']
        for directory_path in os.listdir('.'):
            if directory_path.endswith('.egg-info') and os.path.isdir(directory_path):
                dirs_to_remove.append(directory_path)
        so_file = os.path.join(PACKAGE_NAME, "_C.abi3.so")
        if os.path.exists(so_file):
            print(f"Removing generated library: {so_file}")
            os.remove(so_file)
        for directory_path in dirs_to_remove:
            print(f"Removing directory: {directory_path}")
            shutil.rmtree(directory_path)


class BinaryDistribution(Distribution):
    """
    Make this wheel not a pure python package
    """
    def has_ext_modules(self):
        return True

    def is_pure(self):
        return False


class CMakeBuild(Command):
    """
    CMake config and build
    """
    description = "Build the C++ extension with CMake"
    user_options = []

    def initialize_options(self):
        pass

    def finalize_options(self):
        pass

    def run(self):
        current_working_directory = os.path.abspath(os.path.dirname(__file__))
        project_root = os.path.abspath(os.path.join(current_working_directory, "../../"))
        build_dir = os.path.join(current_working_directory, "build")
        print(f"-- Project Root: {project_root}")
        print(f"-- Build Dir:    {build_dir}")
        experimental_option = os.getenv("ENABLE_EXPERIMENTAL")

        cmake_config_cmd = [
            "cmake",
            "-S", project_root,
            "-B", build_dir,
            "-DENABLE_TORCH_EXTENSION=ON",
            f"-DENABLE_EXPERIMENTAL={experimental_option}",
            "-G", "Ninja",
            "-DCMAKE_BUILD_TYPE=Release"
        ]

        cmake_build_cmd = [
            "cmake",
            "--build", build_dir,
            "--config", "Release",
            "--target", "_C"
        ]

        print(f"Running CMake Config: {' '.join(cmake_config_cmd)}")
        subprocess.check_call(cmake_config_cmd, cwd=current_working_directory)

        print(f"Running CMake Build: {' '.join(cmake_build_cmd)}")
        subprocess.check_call(cmake_build_cmd, cwd=current_working_directory)


class ABI3Wheel(bdist_wheel):
    """
    Force to use abi3 tag
    """
    def run(self):
        self.run_command('cmake_build')
        super().run()
    
    def get_tag(self):
        python, abi, plat = super().get_tag()
        if python.startswith("cp"):
            return "cp38", "abi3", plat
        return python, abi, plat
    

cmdclass = {
    "clean": CleanCommand,
    "bdist_wheel": ABI3Wheel,
    "cmake_build": CMakeBuild
}


setup(
    name=PACKAGE_NAME,
    version=VERSION,
    description="NPU Math Extension",
    packages=find_packages(),
    package_data={
        PACKAGE_NAME: ["_C.abi3.so"],
    },
    distclass=BinaryDistribution,
    cmdclass=cmdclass,
    zip_safe=False,
    install_requires=[
        "torch",
    ],
    # meta-data
    classifiers=[
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "Programming Language :: Python :: 3.13",
        "Operating System :: POSIX :: Linux",
    ],
    python_requires=">=3.8",
)
